
import torch
from torch import nn
import device



class neuralEQRNN(nn.Module):
	"""ReLU RNN over a one-feature sequence with a per-time-step linear readout.

	The recurrent stack is an ``nn.RNN`` with input size 1, ``batch_first=True``,
	no dropout and ``relu`` nonlinearity. Its output at every time step is mapped
	to a single value by a linear layer.

	Args:
		hiddenSize: Number of features in the RNN hidden state.
		numLayers: Number of stacked recurrent layers.
		bidir: If True, the RNN is bidirectional and the linear layer takes
			``2 * hiddenSize`` input features instead of ``hiddenSize``.
	"""

	def __init__(self, hiddenSize, numLayers, bidir=False):
		super().__init__()
		self.rnn = nn.RNN(
			1,
			hiddenSize,
			numLayers,
			batch_first=True,
			bidirectional=bidir,
			dropout=0,
			nonlinearity='relu',
		)
		# A bidirectional RNN concatenates forward and backward features,
		# so the readout layer must accept twice as many inputs.
		numDirections = 2 if bidir else 1
		self.fc = nn.Linear(hiddenSize * numDirections, 1)

	def forward(self, x, hIn=None):
		"""Run the RNN over ``x`` and project every time step to one value.

		Args:
			x: Input of shape (batchSize, seqLength, 1).
			hIn: Initial hidden state of shape
				(D * numLayers, batchSize, hiddenSize), where D is 2 for a
				bidirectional model and 1 otherwise. May be omitted or None,
				in which case ``nn.RNN`` starts from an all-zero state.

		Returns:
			Tuple ``(out, hOut)``: ``out`` has shape (batchSize, seqLength, 1)
			and ``hOut`` has shape (D * numLayers, batchSize, hiddenSize).
		"""
		# out: batchSize, seqLength, D*hiddenSize
		# hOut: D*numLayers, batchSize, hiddenSize
		out, hOut = self.rnn(x, hIn)
		# out: batchSize, seqLength, 1
		out = self.fc(out)
		return out, hOut

# Script entry point: intentionally a no-op, so running this file directly
# only defines the class above.
if __name__ == '__main__':
	pass
